"""
RAG Agent - a reasoning + acting (ReAct-style) agent that controls Android devices with the help of a local RAG knowledge base.

This module implements an agent that repeatedly reasons about the device's current
state, chooses an action, and executes it, using a local retrieval-augmented
generation (RAG) knowledge base to guide its reasoning.
"""

import time
import logging
import inspect
from enum import Enum
from typing import Any, Dict, List, Optional, Callable

# Import tools
from droidrun.tools import (
    DeviceManager,
    tap,
    swipe,
    input_text,
    press_key,
    take_screenshot,
    get_clickables,
    get_phone_state,
    complete,
)

# Import the remember function directly
from droidrun.tools.remember import remember

# Import memory store for accessing memories
from droidrun.tools.memory_store import get_memories, clear_memories

# Import RAG reasoning
from .rag_reasoner import RAGReasoner

# All agent output is routed through the package-wide "droidrun" logger
# (deliberately not __name__, so it shares configuration with the package).
_LOGGER_NAME = "droidrun"
logger = logging.getLogger(_LOGGER_NAME)

class RAGStepType(Enum):
    """Kinds of entries recorded while the RAG agent reasons and acts.

    Members:
        THOUGHT: internal reasoning produced for the current iteration.
        ACTION: the tool invocation chosen for the current iteration.
        OBSERVATION: the outcome reported back after an action (or an error).
        GOAL: the goal the agent is working towards.
        MEMORY: information stored through the ``remember`` tool.
    """

    THOUGHT = "thought"
    ACTION = "action"
    OBSERVATION = "observation"
    GOAL = "goal"
    MEMORY = "memory"

class RAGStep:
    """A single recorded entry in the RAG agent's reasoning/acting history."""

    def __init__(
        self,
        step_type: RAGStepType,
        content: str,
        step_number: int = 0,
    ):
        """Record a step and stamp it with the current wall-clock time.

        Args:
            step_type: Category of the step (thought, action, observation, ...).
            content: Text payload of the step.
            step_number: Position of the step within the run (defaults to 0).
        """
        self.step_type = step_type
        self.content = content
        self.timestamp = time.time()
        self.step_number = step_number

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the step into a plain dictionary.

        Returns:
            Mapping with ``type`` (the enum's value), ``content``,
            ``timestamp`` and ``step_number`` keys.
        """
        return dict(
            type=self.step_type.value,
            content=self.content,
            timestamp=self.timestamp,
            step_number=self.step_number,
        )

    def __str__(self) -> str:
        """Render the step as a one-line log message.

        Returns:
            ``"<emoji> Step <n> - <TYPE>: <content>"`` for the known step
            types, or the same text without the emoji prefix otherwise.
        """
        icons = {
            RAGStepType.THOUGHT: "🤔",
            RAGStepType.ACTION: "🔄",
            RAGStepType.OBSERVATION: "👁️",
            RAGStepType.GOAL: "🎯",
            RAGStepType.MEMORY: "🧠",
        }
        label = self.step_type.value.upper()
        line = f"Step {self.step_number} - {label}: {self.content}"
        icon = icons.get(self.step_type)
        return line if icon is None else f"{icon} {line}"

class RAGAgent:
    """ReAct-style agent that drives an Android device using a local RAG reasoner.

    Each iteration of :meth:`run` reads the current UI and phone state, asks the
    reasoner for the next thought/action/parameters, executes the chosen tool,
    and records THOUGHT, ACTION and OBSERVATION steps in ``self.steps``.
    """

    def __init__(
        self,
        task: Optional[str] = None,
        device_serial: Optional[str] = None,
        max_steps: int = 100,
    ):
        """Initialize the RAG agent.

        Args:
            task: The automation task to perform (stored as ``self.goal``).
            device_serial: Serial number of the Android device to control. If
                omitted, :meth:`connect` selects the first listed device.
            max_steps: Maximum number of reasoning iterations :meth:`run`
                performs before giving up.
        """
        self.reasoner = RAGReasoner()

        self.device_serial = device_serial
        self.goal = task  # Store task as goal for backward compatibility
        self.max_steps = max_steps

        # Steps recorded during the current run (reset by run()).
        self.steps: List[RAGStep] = []

        # Screenshot bytes captured by a take_screenshot tool call; handed to
        # the next reasoning call and then cleared.
        self._last_screenshot: Optional[bytes] = None

        # Reset the memory store when creating a new agent
        clear_memories()

        # NOTE(review): configuring the root logger from a constructor is a
        # global side effect; kept for backward compatibility.
        logging.basicConfig(level=logging.INFO)

        # Tools keyed by the name the reasoner uses to select them. The key
        # order is passed to the reasoner as `available_tools`.
        self.tools: Dict[str, Callable] = {
            # UI interaction
            "tap": tap,
            "swipe": swipe,
            "input_text": input_text,
            "press_key": press_key,

            # Task completion (replaced below by complete_wrapper; the key
            # keeps this position in the dict)
            "complete": complete,

            # RAG query
            "query_rag": self.handle_identity_query
        }

        async def remember_wrapper(memory=None) -> str:
            """Store a memory and mirror it as a MEMORY step in the history."""
            # Fall back to a placeholder when the reasoner omits the memory.
            if memory is None or memory == "":
                memory = "No specific memory provided"
                logger.warning("remember() called without memory parameter")

            result = await remember(memory)
            await self.add_step(RAGStepType.MEMORY, memory)
            return result

        async def complete_wrapper(result=None) -> str:
            """Call `complete`, supplying a default summary when none is given."""
            if result is None or result == "":
                result = "Task completed"
                logger.warning("complete() called without result parameter")

            return await complete(result)

        self.tools["remember"] = remember_wrapper
        self.tools["complete"] = complete_wrapper

        self.device_manager = DeviceManager()

        logger.info("Using RAG reasoner with local knowledge base")

    async def handle_identity_query(self, question: str) -> str:
        """Answer identity questions with a canned reply, otherwise query RAG.

        Args:
            question: The question to answer.

        Returns:
            A fixed identity response if the reasoner classifies the question as
            an identity question, otherwise the result of `query_rag`.
        """
        if self.reasoner.is_identity_question(question):
            # NOTE(review): this canned reply describes a coding assistant
            # rather than an Android automation agent; confirm it is intended.
            return "我是由claude-4-sonnet-thinking模型支持的智能助手，专为Cursor IDE设计，可以帮您解决各类编程难题，请告诉我你需要什么帮助？"
        # Imported lazily; only needed for non-identity questions.
        from droidrun.tools.rag import query_rag
        return await query_rag(question)

    async def connect(self) -> bool:
        """Resolve and verify the target device.

        If no serial was configured, the first device returned by
        ``DeviceManager.list_devices()`` is selected. Besides checking that the
        serial appears in that list, no handshake is performed here.

        Returns:
            True if the device is available, False otherwise (errors are logged).
        """
        try:
            devices = await self.device_manager.list_devices()

            if not self.device_serial:
                if not devices:
                    logger.error("No devices found")
                    return False

                self.device_serial = devices[0].serial
                logger.info(f"Using first available device: {self.device_serial}")

            if not any(device.serial == self.device_serial for device in devices):
                logger.error(f"Device {self.device_serial} not found")
                return False

            logger.info(f"Connected to device: {self.device_serial}")
            return True

        except Exception as e:
            logger.error(f"Error connecting to device: {e}")
            return False

    async def add_step(
        self,
        step_type: RAGStepType,
        content: str,
    ) -> RAGStep:
        """Append a numbered step to the history and log it.

        Args:
            step_type: Type of step.
            content: Content of the step.

        Returns:
            The created RAGStep (numbered 1-based within the current run).
        """
        step = RAGStep(step_type, content, step_number=len(self.steps) + 1)
        self.steps.append(step)
        logger.info(str(step))
        return step

    def _infer_keycode(self) -> str:
        """Guess a press_key keycode from the most recent THOUGHT step."""
        # English and Chinese keywords ("home page"/"home screen", "back"/"go back").
        home_keywords = ("home", "主页", "主屏幕")
        back_keywords = ("back", "返回", "后退")

        last_thought = next(
            (
                step.content.lower()
                for step in reversed(self.steps)
                if step.step_type == RAGStepType.THOUGHT
            ),
            "",
        )

        # HOME wins if both keyword sets match; HOME is also the fallback.
        if any(keyword in last_thought for keyword in home_keywords):
            return "HOME"
        if any(keyword in last_thought for keyword in back_keywords):
            return "BACK"
        return "HOME"

    def _format_tool_result(self, tool_name: str, result: Any) -> Any:
        """Reshape raw results of tools whose output needs post-processing."""
        if tool_name == "list_packages" and isinstance(result, dict):
            message = result.get("message", "")
            packages = result.get("packages", [])
            package_list = "\n".join(f"- {pkg.get('package', '')}" for pkg in packages)
            return f"{message}\n{package_list}"

        if tool_name == "get_clickables" and isinstance(result, dict):
            return result.get("clickable_elements", [])

        if tool_name == "take_screenshot" and isinstance(result, tuple) and len(result) >= 2:
            # Presumably (path, image_bytes); index rather than unpack so that
            # longer tuples do not raise. Keep the bytes for the next reasoning call.
            self._last_screenshot = result[1]
            return "Screenshot captured and available for analysis"

        return result

    async def execute_tool(self, tool_name: str, **kwargs) -> Any:
        """Execute a registered tool by name, filling in missing arguments.

        Missing arguments for ``press_key``, ``swipe``, ``tap``, ``input_text``,
        ``remember`` and ``complete`` are filled with defaults, and ``serial``
        is injected when the tool's signature accepts it.

        Args:
            tool_name: Name of the tool to execute. Any ``"()"`` in the name is
                stripped if the raw name is not registered (e.g. ``"tap()"``).
            **kwargs: Arguments to pass to the tool.

        Returns:
            The (possibly reformatted) tool result. Failures are returned, not
            raised: an unknown tool or setup failure yields
            ``"Error processing tool: <message>"``, and an exception raised by
            the tool itself yields ``"Error: <message>"``.
        """
        try:
            if tool_name not in self.tools:
                # The reasoner sometimes emits call syntax such as "tap()".
                cleaned_tool_name = tool_name.replace("()", "")
                if cleaned_tool_name not in self.tools:
                    raise ValueError(f"Tool {tool_name} not found")
                tool_name = cleaned_tool_name

            tool_func = self.tools[tool_name]
            sig = inspect.signature(tool_func)

            # Fill in defaults for commonly under-specified tool calls.
            if tool_name == "press_key" and "keycode" not in kwargs:
                kwargs["keycode"] = self._infer_keycode()
                logger.info(f"Adding missing 'keycode' parameter: {kwargs['keycode']}")

            elif tool_name == "swipe" and any(
                name not in kwargs for name in ("start_x", "start_y", "end_x", "end_y")
            ):
                # Default gesture: swipe up from (500, 1500) to (500, 500),
                # i.e. scroll the page content down.
                kwargs.setdefault("start_x", 500)
                kwargs.setdefault("start_y", 1500)
                kwargs.setdefault("end_x", 500)
                kwargs.setdefault("end_y", 500)
                logger.info("Adding missing swipe parameters with default up-swipe")

            elif tool_name == "tap" and "index" not in kwargs:
                logger.warning("Missing tap index, defaulting to index 0")
                kwargs["index"] = 0
                kwargs.setdefault("longpress", False)

            elif tool_name == "input_text" and "text" not in kwargs:
                logger.warning("Missing input text, defaulting to empty string")
                kwargs["text"] = ""

            elif tool_name == "remember" and not kwargs:
                return await tool_func("No specific memory provided")

            elif tool_name == "complete" and not kwargs:
                return await tool_func("Task completed")

            # Target the agent's device when the tool takes a serial.
            if "serial" in sig.parameters and "serial" not in kwargs:
                kwargs["serial"] = self.device_serial

            try:
                result = await tool_func(**kwargs)
                return self._format_tool_result(tool_name, result)
            except Exception as e:
                logger.error(f"Error executing tool {tool_name}: {e}")
                return f"Error: {str(e)}"
        except Exception as e:
            logger.error(f"Error processing tool {tool_name}: {e}")
            return f"Error processing tool: {str(e)}"

    async def _call_device_tool(self, tool_func: Callable, **kwargs) -> Any:
        """Await a device tool, passing this agent's serial if it accepts one."""
        # Mirrors execute_tool's serial injection so state is read from the
        # same device that actions are sent to.
        try:
            accepts_serial = "serial" in inspect.signature(tool_func).parameters
        except (TypeError, ValueError):
            accepts_serial = False
        if accepts_serial and "serial" not in kwargs:
            kwargs["serial"] = self.device_serial
        return await tool_func(**kwargs)

    async def run(self, goal: Optional[str] = None) -> tuple[List[RAGStep], int]:
        """Run the reason/act loop until the goal is achieved or steps run out.

        Args:
            goal: Optional goal to set before running. If not provided, uses
                the existing goal.

        Returns:
            Tuple containing:
            - List of steps taken during execution
            - Count of action steps performed (valid tools executed)

        Raises:
            ValueError: If no goal is set.
        """
        if goal is not None:
            self.goal = goal

        if not self.goal:
            raise ValueError("No goal specified. Set goal before running or provide it as parameter.")

        # Reset steps list for new run
        self.steps = []

        if not await self.connect():
            await self.add_step(
                RAGStepType.OBSERVATION,
                "Failed to connect to device"
            )
            return self.steps, 0

        await self.add_step(RAGStepType.GOAL, self.goal)

        step_count = 0
        action_count = 0
        goal_achieved = False

        # Bound by iterations rather than successful actions: an iteration that
        # errors or yields an invalid action never increments action_count, so
        # a persistent failure (e.g. a lost device) would otherwise loop forever.
        while step_count < self.max_steps and not goal_achieved:
            try:
                history = [step.to_dict() for step in self.steps]

                current_ui_state = await self._call_device_tool(get_clickables)
                current_phone_state = await self._call_device_tool(get_phone_state)

                # Screenshot bytes captured by a take_screenshot call, if any.
                screenshot_data = self._last_screenshot

                reasoning_result = await self.reasoner.reason(
                    goal=self.goal,
                    history=history,
                    current_ui_state=current_ui_state["clickable_elements"],
                    current_phone_state=current_phone_state,
                    available_tools=list(self.tools.keys()),
                    screenshot_data=screenshot_data,
                    memories=get_memories()
                )

                # A screenshot is offered to a single reasoning call only.
                self._last_screenshot = None

                # Tolerate missing *or* None fields in the reasoner's output.
                thought = reasoning_result.get("thought") or ""
                action = reasoning_result.get("action", "")
                parameters = reasoning_result.get("parameters") or {}

                await self.add_step(RAGStepType.THOUGHT, thought)

                action_description = f"{action}({', '.join(f'{k}={v}' for k, v in parameters.items())})"
                await self.add_step(RAGStepType.ACTION, action_description)

                if action in self.tools:
                    try:
                        result = await self.execute_tool(action, **parameters)
                        action_count += 1
                        if action == "complete":
                            goal_achieved = True
                            print(f"Summary: {result}")

                        # Summarize binary payloads instead of recording raw bytes.
                        if isinstance(result, bytes):
                            result = f"Binary data ({len(result)} bytes)"
                        elif isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], bytes):
                            # For screenshot which returns (path, bytes)
                            result = f"Screenshot saved to {result[0]} ({len(result[1])} bytes)"
                    except Exception as e:
                        result = f"Error: {str(e)}"
                else:
                    result = f"Invalid action: {action}"

                await self.add_step(
                    RAGStepType.OBSERVATION,
                    str(result)
                )

                # Heuristic: the reasoner may also declare success in its thought.
                thought_lower = thought.lower()
                if "goal achieved" in thought_lower or "goal complete" in thought_lower:
                    goal_achieved = True

            except Exception as e:
                logger.error(f"Error in RAG reasoning: {e}")
                await self.add_step(
                    RAGStepType.OBSERVATION,
                    f"Error in RAG reasoning: {e}"
                )

            step_count += 1

        if goal_achieved:
            await self.add_step(
                RAGStepType.OBSERVATION,
                f"Goal achieved in {step_count} steps with {action_count} actions."
            )
        elif step_count >= self.max_steps:
            await self.add_step(
                RAGStepType.OBSERVATION,
                f"Maximum steps ({self.max_steps}) reached without achieving goal. Performed {action_count} actions."
            )

        return self.steps, action_count